{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the whole dataset file and parse it as JSON into `data`.\n",
    "with open('dataset-bpe.json') as fopen:\n",
    "    data = json.loads(fopen.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the encoder inputs (X) and decoder targets (Y) of both splits out of the loaded dict.\n",
    "train_X, train_Y = data['train_X'], data['train_Y']\n",
    "test_X, test_Y = data['test_X'], data['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "EOS = 2\n",
    "GO = 1\n",
    "vocab_size = 32000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Terminate every target sequence with the EOS id (the same id greedy decoding stops on).\n",
    "# NOTE: this cell is not idempotent - running it a second time appends a second EOS.\n",
    "train_Y = [seq + [EOS] for seq in train_Y]\n",
    "test_Y = [seq + [EOS] for seq in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "# Hyper-parameter table. It is a defaultdict, so reading a key that is not\n",
    "# listed here yields None instead of raising KeyError.\n",
    "BASE_PARAMS = defaultdict(lambda: None)\n",
    "BASE_PARAMS.update(\n",
    "    # --- Input ---\n",
    "    default_batch_size = 2048,  # upper bound on tokens per batch of examples\n",
    "    default_batch_size_tpu = 32768,\n",
    "    max_length = 256,  # upper bound on tokens per example\n",
    "\n",
    "    # --- Model ---\n",
    "    initializer_gain = 1.0,  # used when initialising trainable variables\n",
    "    vocab_size = vocab_size,  # number of tokens in the vocabulary\n",
    "    hidden_size = 512,  # model dimension of the hidden layers\n",
    "    num_hidden_layers = 6,  # layers in the encoder and decoder stacks\n",
    "    num_heads = 8,  # heads in multi-headed attention\n",
    "    filter_size = 2048,  # inner dimension of the feed-forward network\n",
    "\n",
    "    # --- Dropout (only applied when training) ---\n",
    "    layer_postprocess_dropout = 0.1,\n",
    "    attention_dropout = 0.1,\n",
    "    relu_dropout = 0.1,\n",
    "\n",
    "    # --- Training ---\n",
    "    label_smoothing = 0.1,\n",
    "    learning_rate = 2.0,\n",
    "    learning_rate_decay_rate = 1.0,\n",
    "    learning_rate_warmup_steps = 16000,\n",
    "\n",
    "    # --- Adam optimizer ---\n",
    "    optimizer_adam_beta1 = 0.9,\n",
    "    optimizer_adam_beta2 = 0.997,\n",
    "    optimizer_adam_epsilon = 1e-09,\n",
    "\n",
    "    # --- Default prediction settings ---\n",
    "    extra_decode_length = 50,\n",
    "    beam_size = 4,\n",
    "    alpha = 0.6,  # length-normalisation strength in beam search\n",
    "\n",
    "    # --- TPU-specific ---\n",
    "    use_tpu = False,\n",
    "    static_batch = False,\n",
    "    allow_ffn_pad = True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/translation/transformer/embedding_layer.py:24: The name tf.layers.Layer is deprecated. Please use tf.compat.v1.layers.Layer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tensor2tensor.utils import beam_search\n",
    "from transformer import embedding_layer\n",
    "from transformer.transformer import EncoderStack\n",
    "from transformer import model_utils\n",
    "\n",
    "def pad_second_dim(x, desired_size):\n",
    "    \"\"\"Append zeros to a rank-3 tensor along axis 1 until it has `desired_size` entries there.\n",
    "\n",
    "    The zero block is tiled from a Python 0.0 literal and concatenated to `x`,\n",
    "    so `x` is assumed to be float32 (TODO confirm against callers). The tile\n",
    "    count on axis 1 is `desired_size - tf.shape(x)[1]`; callers should make\n",
    "    sure that difference is not negative.\n",
    "    \"\"\"\n",
    "    dims = tf.shape(x)\n",
    "    pad_shape = tf.stack([dims[0], desired_size - dims[1], dims[2]], 0)\n",
    "    zeros = tf.tile([[[0.0]]], pad_shape)\n",
    "    return tf.concat([x, zeros], 1)\n",
    "\n",
    "class Translator:\n",
    "    \"\"\"Transformer encoder feeding a stacked LSTM decoder (TF1 graph mode).\n",
    "\n",
    "    Constructing an instance adds the whole model to the default graph and\n",
    "    exposes these tensors / ops as attributes:\n",
    "\n",
    "    * X, Y: int32 placeholders holding source / target token ids. Id 0 is\n",
    "      padding: sequence lengths are the per-row counts of non-zero ids.\n",
    "    * encoded: output of the Transformer EncoderStack over X.\n",
    "    * training_logits: teacher-forced decoder logits; the decoder input is\n",
    "      GO followed by Y without its last column.\n",
    "    * fast_result: greedy decoding, capped at twice the longest source length.\n",
    "    * cost, optimizer: loss from utils.padded_cross_entropy_loss (called with\n",
    "      the label-smoothing value from BASE_PARAMS) and its Adam train op.\n",
    "    * prediction, accuracy: argmax ids at the non-padding target positions and\n",
    "      the fraction of them that equal Y.\n",
    "\n",
    "    Args:\n",
    "        num_layers: number of LSTM layers in the decoder.\n",
    "        train: Python bool forwarded to EncoderStack; when true, dropout is\n",
    "            also applied to the encoder inputs. It is baked into the graph, so\n",
    "            a model built with train=True keeps that dropout in every\n",
    "            sess.run, evaluation and greedy decoding included.\n",
    "        learning_rate: learning rate of the Adam optimizer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_layers, train = True, learning_rate = 1e-4):\n",
    "        hidden_size = BASE_PARAMS[\"hidden_size\"]\n",
    "        # The embedding layer and the output projection must agree on the\n",
    "        # vocabulary size, so both read it from the same place.\n",
    "        num_tokens = BASE_PARAMS[\"vocab_size\"]\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "\n",
    "        # Padding is id 0, so counting non-zero ids gives each row's length.\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "\n",
    "        self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(\n",
    "            num_tokens, hidden_size, method=\"gather\")\n",
    "        self.encoder_stack = EncoderStack(BASE_PARAMS, train)\n",
    "        with tf.name_scope(\"encode\"):\n",
    "            # Prepare inputs to the layer stack by adding positional encodings\n",
    "            # and applying dropout.\n",
    "            embedded_inputs = self.embedding_softmax_layer(self.X)\n",
    "            inputs_padding = model_utils.get_padding(self.X)\n",
    "            attention_bias = model_utils.get_padding_bias(self.X)\n",
    "\n",
    "            with tf.name_scope(\"add_pos_encoding\"):\n",
    "                length = tf.shape(embedded_inputs)[1]\n",
    "                pos_encoding = model_utils.get_position_encoding(length, hidden_size)\n",
    "                encoder_inputs = embedded_inputs + pos_encoding\n",
    "\n",
    "            if train:\n",
    "                # `rate` is the drop probability; passing it by keyword avoids\n",
    "                # the deprecated positional keep_prob argument.\n",
    "                encoder_inputs = tf.nn.dropout(\n",
    "                    encoder_inputs, rate = BASE_PARAMS[\"layer_postprocess_dropout\"])\n",
    "\n",
    "            self.encoded = self.encoder_stack(encoder_inputs, attention_bias, inputs_padding)\n",
    "\n",
    "        # The encoder output at position 0 is used as the sentence summary: two\n",
    "        # separate tanh projections of it become the initial LSTM (c, h) state.\n",
    "        first_token_tensor = self.encoded[:, 0, :]\n",
    "        c = tf.layers.dense(first_token_tensor, hidden_size, activation = tf.tanh)\n",
    "        h = tf.layers.dense(first_token_tensor, hidden_size, activation = tf.tanh)\n",
    "\n",
    "        def cells(reuse=False):\n",
    "            return tf.nn.rnn_cell.LSTMCell(\n",
    "                hidden_size, initializer=tf.orthogonal_initializer(), reuse=reuse)\n",
    "\n",
    "        # Every decoder layer starts from the same (c, h) pair.\n",
    "        lstm_state = tf.nn.rnn_cell.LSTMStateTuple(c=c, h=h)\n",
    "        encoder_state = tuple([lstm_state] * num_layers)\n",
    "        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)])\n",
    "\n",
    "        # The decoder looks its inputs up in the encoder's embedding matrix.\n",
    "        embedding = self.embedding_softmax_layer.shared_weights\n",
    "        # Teacher forcing: drop the last target column and put GO in front.\n",
    "        main = self.Y[:, :-1]\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        # One projection layer object, shared by the training and greedy decoders.\n",
    "        dense = tf.layers.Dense(num_tokens)\n",
    "\n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(embedding, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = training_helper,\n",
    "                initial_state = encoder_state,\n",
    "                output_layer = dense)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "\n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                embedding = embedding,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS)\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = predicting_helper,\n",
    "                initial_state = encoder_state,\n",
    "                output_layer = dense)\n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "        self.fast_result = predicting_decoder_output.sample_id\n",
    "\n",
    "        # Boolean mask of the non-padding target positions (bool is what\n",
    "        # tf.boolean_mask is documented to take).\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len))\n",
    "\n",
    "        xentropy, weights = utils.padded_cross_entropy_loss(\n",
    "            self.training_logits, self.Y, BASE_PARAMS[\"label_smoothing\"], num_tokens)\n",
    "        self.cost = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)\n",
    "\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "\n",
    "        # Token accuracy, measured only where the target is not padding.\n",
    "        y_t = tf.cast(tf.argmax(self.training_logits, axis = 2), tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n",
      "WARNING:tensorflow:From /home/husein/translation/transformer/attention_layer.py:39: The name tf.layers.Dense is deprecated. Please use tf.compat.v1.layers.Dense instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/translation/transformer/embedding_layer.py:48: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/translation/transformer/embedding_layer.py:48: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/translation/transformer/embedding_layer.py:51: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/translation/transformer/embedding_layer.py:70: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n",
      "WARNING:tensorflow:From <ipython-input-9-84a431fde5c5>:39: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/autograph/converters/directives.py:119: The name tf.rsqrt is deprecated. Please use tf.math.rsqrt instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/translation/transformer/ffn_layer.py:65: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/husein/translation/transformer/ffn_layer.py:65: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n",
      "Tensor(\"encode/encoder_stack/layer_normalization/add_1:0\", shape=(?, ?, 512), dtype=float32)\n",
      "WARNING:tensorflow:From <ipython-input-9-84a431fde5c5>:50: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:From <ipython-input-9-84a431fde5c5>:59: LSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:From <ipython-input-9-84a431fde5c5>:64: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:958: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:962: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "WARNING:tensorflow:From /home/husein/translation/transformer/utils.py:82: The name tf.log is deprecated. Please use tf.math.log instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start from an empty default graph so that re-running this cell does not add a second model to it.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Translator(num_layers = 2)\n",
    "# Give every variable created by the constructor its initial value.\n",
    "tf.global_variables_initializer().run(session = sess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short alias for the Keras padding helper; used below to turn lists of token ids into rectangular batches.\n",
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[ 1584,  8487,  8487,  8487,  8487,  8487, 21053, 31856, 31856,\n",
       "         31856, 31856, 31856, 31856, 31856, 31856, 31856,  9791,  9791,\n",
       "          9791,  9791,  9791,  9791,  2356, 13770, 13770, 13770, 13770,\n",
       "         13770, 13770,  3685,  3685,  3685,  1210,  1210, 25848, 25848,\n",
       "         25848,  5243,  5243,  5243,  3273, 11782, 11782, 11782, 10310,\n",
       "         10310, 23115, 23115, 23115, 12151, 12151, 12151, 12151, 12151,\n",
       "         11810,  3376,  7333,  7333,  7333, 14306, 14306, 14306, 14306,\n",
       "         22549, 22549, 22549,  6729, 16134, 16134, 16134, 16134,  3107],\n",
       "        [30245, 30245, 30245, 19217, 12923, 12923, 12923, 12923, 14609,\n",
       "         14609,  6175,  6175,  6175,  6175,  6175,  6175,  6175,  6175,\n",
       "          7999, 10259, 10259, 10259, 10259, 10259, 10259, 27756,  6938,\n",
       "          6938,  6938,  6938,  6938, 17095, 17095, 17095, 29059, 29059,\n",
       "         13342, 12696, 12696, 21499,  7348,  7348,  7348, 11768, 11768,\n",
       "         11768, 11768, 28451, 28451, 20921, 20921, 13516, 13516,  5371,\n",
       "          5371,  5371,  5371,  5371,  7513, 19647,  6604, 19647, 20287,\n",
       "         20287, 20287, 20287, 20287,   914,   914,   914,  6653,  6653],\n",
       "        [ 4258,  4258, 16237, 29131, 29131, 29131, 29131, 29131, 29131,\n",
       "          1149,  1149,  1149,  1149, 23958, 23958, 23958, 23958, 23958,\n",
       "         23958, 23958, 23958, 23958, 15456, 15456, 15456,  7152,  7152,\n",
       "         26405, 26405, 26405,  4060,  8118,  8118, 24555, 24555, 16495,\n",
       "         16495, 11441,    33, 16915, 24312, 24312,  2105,  2105, 14821,\n",
       "         14821, 23826, 23826, 24312, 24312,  8306,  8306,  8306,  8306,\n",
       "         10030,  2392,  2392, 19341, 19341,  4663,  4663,  4663,  4663,\n",
       "          4663,  4663,  4663,  9106, 30281, 18878, 18878, 23308, 23308],\n",
       "        [27249, 18266, 18266, 18266, 18266, 20455, 20455, 20455, 20455,\n",
       "         20455,  6455, 25248, 25248, 25248, 25248, 25248, 25248, 25248,\n",
       "         25248, 25248, 25248, 19283, 19283, 30171, 27579, 27579, 27579,\n",
       "         30171, 23245, 27754, 27754,  5273,  5273,  5273, 23479, 19563,\n",
       "          7740,  7740, 23479,  2790, 23521, 23521, 13976, 13976, 29844,\n",
       "         29844, 29844, 12102, 12102, 12102, 12072, 28510,  6810,  9048,\n",
       "          9048,  9048,  9048, 20050, 16995, 16995, 19365, 19365, 19365,\n",
       "         19365, 19365, 19365,  8495,  2949, 23197, 23197, 23197, 23197],\n",
       "        [26977, 26977, 26977, 27925, 27925, 27925, 27925, 27925, 25367,\n",
       "         25367, 25367, 25367, 25367, 25367,  6083,  6083,  6083,  6083,\n",
       "         27708,  8580,  8580,  8580, 15456, 15456, 15456,  9243,  9243,\n",
       "          9243, 24813, 24813, 29594,  6149, 13929, 19567, 19567, 19567,\n",
       "          5595,  5595,  5595,  5595,  5595,  8715,  6287,  3730, 14061,\n",
       "          1626,  1626,  1626, 21497, 21497, 21497, 20521,  3839, 27762,\n",
       "         27762, 21263, 21263, 21263, 12969, 13054, 13054, 13054, 31077,\n",
       "         31077, 31077, 15947, 15947, 31077, 15052, 15052, 15947, 15947],\n",
       "        [ 3632,  4258,  3594, 14878, 14878, 14878, 14878, 14878,  3784,\n",
       "          3784,  8359,  8359,  8359, 17454, 17454, 14987, 14987, 14987,\n",
       "         14987, 14987, 14987, 14987, 14987, 14987, 14987, 14987, 14987,\n",
       "         14987, 14987, 14987, 14987,  4538, 27189, 27189,  2660, 14337,\n",
       "         14337, 14337, 13444, 19890, 19890, 19890, 19890,  9642,  6193,\n",
       "          6193,  6193,  3441,  3909, 28706, 15900, 15900, 31597, 31597,\n",
       "         31597, 27788, 27788,  3309, 13110, 13110, 23568,   757, 30792,\n",
       "         30792, 23568, 13075, 13075, 13075, 28755, 28755, 28755, 30949],\n",
       "        [ 6019,  7271,  7271,  7638, 16890, 16890, 16890,  9814, 17932,\n",
       "         17932, 17932, 25186, 25186, 25186, 25186, 13216, 13216, 20984,\n",
       "         20984, 20984, 20984, 24157, 24157, 24157, 28151, 28151, 10582,\n",
       "         15978, 15978, 15978, 15978, 15978,  4937,  4937, 15088, 20678,\n",
       "          7801,  7801, 20678,  4012,  4012,  4012,  9844, 31011, 31011,\n",
       "         31011, 31011, 16000, 16000,  9347, 29847, 29847, 17779,  4612,\n",
       "          4612,  4612, 11597, 11597, 28698, 28698, 28698,  5712, 14933,\n",
       "         14933, 14933, 20523, 15927, 15927, 15927, 27889, 27889, 19547],\n",
       "        [10237, 10237,  8219, 28485, 28485, 28485, 28485, 28485, 28485,\n",
       "         28485, 28485, 28485, 11708, 11708, 11708, 11708, 11708,  9938,\n",
       "          9938,  9938,  9938,  9938, 22653, 22653, 16327, 16327, 16327,\n",
       "         16582, 16582, 29607, 29607, 29607, 19236,  9760,  9760,  9760,\n",
       "         29167, 29167, 29167, 26223, 26223, 26223, 26611, 26223, 26223,\n",
       "         14544, 25195, 25195, 25195,  9261,  6791,  6791, 15413, 15413,\n",
       "         15413, 15413,  1189, 16323, 16323, 15107, 15107, 15107,  9947,\n",
       "          9947,  6063, 17819, 30453, 30453,  4128, 29946, 11952, 11952],\n",
       "        [28869, 28869, 18429,  8282,  8282,  8282,  8282,  8282,  3378,\n",
       "          3378,  3378,  3378,  3378,  3378,  4463,  4463,  4463,  4463,\n",
       "         24465, 24158, 24158, 24158, 25390,  9718,  9718,  9718, 25718,\n",
       "         25718, 25718, 14504, 21552, 21552, 21552, 20609, 20609,  5059,\n",
       "          5059, 18574, 18574, 18574, 26819, 31704,  6777,  6777,  6777,\n",
       "          6777, 21266, 21266, 21266, 21266, 21266, 17566, 17566, 16164,\n",
       "         16164, 12992, 12992,  3800, 22274, 22274, 22274, 22055, 22055,\n",
       "         22055, 30729, 30729, 23667,  4393,  4393,  4393,  4393,  4393],\n",
       "        [14912,  1172,  1172,  1172,  1172,  9983, 29068, 29068, 29068,\n",
       "         29068, 29068, 29068, 29068, 29068, 29068, 20722, 20722, 24901,\n",
       "         24901, 24901, 24901, 24901, 24901, 24901, 18260, 18260, 18260,\n",
       "         18260, 18260, 23992, 25418, 25418, 25418,  4895, 20353, 20353,\n",
       "         20353, 20353, 26525, 26525, 26525, 26525,  5908, 11914, 11914,\n",
       "         11739, 11739, 11739, 30829, 30829,  6703,  7126, 10264, 10264,\n",
       "         10264, 18178, 18178, 18178, 18178, 18178, 30470, 30470, 22539,\n",
       "         23507, 23507,  5938,  5938,  5938,  6034,  5546, 24574, 24574]],\n",
       "       dtype=int32), 9.010772, 0.0]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test on the first ten training pairs: one forward pass for greedy output, loss and accuracy.\n",
    "batch_x = pad_sequences(train_X[:10], padding = 'post')\n",
    "batch_y = pad_sequences(train_Y[:10], padding = 'post')\n",
    "\n",
    "sess.run(\n",
    "    fetches = [model.fast_result, model.cost, model.accuracy],\n",
    "    feed_dict = {model.X: batch_x, model.Y: batch_y},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:17<00:00,  2.53it/s, accuracy=0.103, cost=5.99] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.42it/s, accuracy=0.129, cost=5.64] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg loss 6.389846, training avg acc 0.074503\n",
      "epoch 1, testing avg loss 5.942791, testing avg acc 0.106623\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:10<00:00,  2.56it/s, accuracy=0.149, cost=5.34]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.71it/s, accuracy=0.177, cost=4.86]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg loss 5.591633, training avg acc 0.132945\n",
      "epoch 2, testing avg loss 5.276812, testing avg acc 0.157447\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:09<00:00,  2.56it/s, accuracy=0.183, cost=4.95]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.87it/s, accuracy=0.188, cost=4.46]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg loss 5.062870, training avg acc 0.175475\n",
      "epoch 3, testing avg loss 4.879865, testing avg acc 0.191697\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:54<00:00,  2.63it/s, accuracy=0.205, cost=4.68]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.83it/s, accuracy=0.22, cost=4.21] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg loss 4.727176, training avg acc 0.203113\n",
      "epoch 4, testing avg loss 4.619950, testing avg acc 0.211189\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.217, cost=4.48]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.67it/s, accuracy=0.231, cost=4.03]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg loss 4.493851, training avg acc 0.221366\n",
      "epoch 5, testing avg loss 4.435090, testing avg acc 0.226009\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.235, cost=4.31]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.70it/s, accuracy=0.247, cost=3.89]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg loss 4.316304, training avg acc 0.235867\n",
      "epoch 6, testing avg loss 4.293239, testing avg acc 0.237922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.244, cost=4.16]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.70it/s, accuracy=0.28, cost=3.75] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg loss 4.168743, training avg acc 0.249200\n",
      "epoch 7, testing avg loss 4.174297, testing avg acc 0.248792\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.256, cost=4.03]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.70it/s, accuracy=0.29, cost=3.61] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg loss 4.039792, training avg acc 0.261463\n",
      "epoch 8, testing avg loss 4.072978, testing avg acc 0.258610\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.274, cost=3.91]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.67it/s, accuracy=0.269, cost=3.57]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg loss 3.924547, training avg acc 0.273151\n",
      "epoch 9, testing avg loss 3.989358, testing avg acc 0.266262\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.285, cost=3.79]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.70it/s, accuracy=0.317, cost=3.49]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg loss 3.820036, training avg acc 0.284692\n",
      "epoch 10, testing avg loss 3.908415, testing avg acc 0.276933\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:10<00:00,  2.56it/s, accuracy=0.305, cost=3.68]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.67it/s, accuracy=0.323, cost=3.44]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg loss 3.724962, training avg acc 0.295350\n",
      "epoch 11, testing avg loss 3.838239, testing avg acc 0.283988\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:53<00:00,  2.63it/s, accuracy=0.304, cost=3.59]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.61it/s, accuracy=0.339, cost=3.38]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg loss 3.636323, training avg acc 0.305995\n",
      "epoch 12, testing avg loss 3.777020, testing avg acc 0.291484\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:38<00:00,  2.70it/s, accuracy=0.318, cost=3.5] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.71it/s, accuracy=0.349, cost=3.33]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg loss 3.555310, training avg acc 0.315882\n",
      "epoch 13, testing avg loss 3.729742, testing avg acc 0.297013\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.336, cost=3.41]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.69it/s, accuracy=0.333, cost=3.33]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg loss 3.479341, training avg acc 0.325446\n",
      "epoch 14, testing avg loss 3.685001, testing avg acc 0.301740\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.347, cost=3.32]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.70it/s, accuracy=0.355, cost=3.25]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg loss 3.407731, training avg acc 0.334744\n",
      "epoch 15, testing avg loss 3.641921, testing avg acc 0.308633\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.355, cost=3.25]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.69it/s, accuracy=0.36, cost=3.23] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg loss 3.342026, training avg acc 0.343555\n",
      "epoch 16, testing avg loss 3.596954, testing avg acc 0.314400\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:37<00:00,  2.71it/s, accuracy=0.368, cost=3.16]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.69it/s, accuracy=0.36, cost=3.21] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg loss 3.278260, training avg acc 0.351899\n",
      "epoch 17, testing avg loss 3.560529, testing avg acc 0.319729\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop:  82%|████████▏ | 1283/1563 [07:52<01:38,  2.84it/s, accuracy=0.372, cost=3.11]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "minibatch loop: 100%|██████████| 1563/1563 [09:39<00:00,  2.70it/s, accuracy=0.384, cost=3.09]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.49it/s, accuracy=0.36, cost=3.22] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg loss 3.217587, training avg acc 0.360333\n",
      "epoch 18, testing avg loss 3.535361, testing avg acc 0.322503\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:40<00:00,  2.69it/s, accuracy=0.398, cost=3.01]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:07<00:00,  5.70it/s, accuracy=0.355, cost=3.19]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg loss 3.160061, training avg acc 0.368202\n",
      "epoch 19, testing avg loss 3.505346, testing avg acc 0.326451\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [09:49<00:00,  2.65it/s, accuracy=0.387, cost=2.94]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:17<00:00,  2.29it/s, accuracy=0.371, cost=3.15]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg loss 3.106020, training avg acc 0.375853\n",
      "epoch 20, testing avg loss 3.478847, testing avg acc 0.330660\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
    "    train_loss, train_acc, test_loss, test_acc = [], [], [], []\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_X))\n",
    "        batch_x = pad_sequences(train_X[i : index], padding='post')\n",
    "        batch_y = pad_sequences(train_Y[i : index], padding='post')\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y}\n",
    "        accuracy, loss, _ = sess.run([model.accuracy,model.cost,model.optimizer],\n",
    "                                    feed_dict = feed)\n",
    "        train_loss.append(loss)\n",
    "        train_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    \n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_X))\n",
    "        batch_x = pad_sequences(test_X[i : index], padding='post')\n",
    "        batch_y = pad_sequences(test_Y[i : index], padding='post')\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y,}\n",
    "        accuracy, loss = sess.run([model.accuracy,model.cost],\n",
    "                                    feed_dict = feed)\n",
    "\n",
    "        test_loss.append(loss)\n",
    "        test_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,\n",
    "                                                                 np.mean(train_loss),np.mean(train_acc)))\n",
    "    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,\n",
    "                                                              np.mean(test_loss),np.mean(test_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import bleu_hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 40/40 [00:28<00:00,  1.39it/s]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "for i in tqdm.tqdm(range(0, len(test_X), batch_size)):\n",
    "    index = min(i + batch_size, len(test_X))\n",
    "    batch_x = pad_sequences(test_X[i : index], padding='post')\n",
    "    feed = {model.X: batch_x}\n",
    "    p = sess.run(model.fast_result,feed_dict = feed)\n",
    "    result = []\n",
    "    for row in p:\n",
    "        result.append([i for i in row if i > 3])\n",
    "    results.extend(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "rights = []\n",
    "for r in test_Y:\n",
    "    rights.append([i for i in r if i > 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.049064703"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bleu_hook.compute_bleu(reference_corpus = rights,\n",
    "                       translation_corpus = results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
